{
 "cells": [
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-01T07:30:13.061457Z",
     "start_time": "2024-09-01T07:30:12.873060Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import os\n",
    "import json\n",
    "import pandas as pd\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from graphs import get_nx_graph, verbalize_neighbors_triples_from_triples, verbalize_neighbors_triples_from_graph\n",
    "from models import KnowledgeGraphLLM"
   ],
   "id": "b2ed867e157de175",
   "outputs": [],
   "execution_count": 47
  },
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-09-01T07:26:07.031056Z",
     "start_time": "2024-09-01T07:26:07.027119Z"
    }
   },
   "source": [
    "# Load the extracted triples: one JSON object per line, read via the\n",
    "# keys 's', 'p' and 'o' and stored as (s, p, o) tuples.\n",
    "candidate_triples = []\n",
    "# `with` closes the file handle deterministically instead of leaking it.\n",
    "with open('output/test/step-01.jsonl', 'r') as triples_file:\n",
    "    for line in triples_file:\n",
    "        if not line.strip():\n",
    "            continue  # tolerate blank lines (e.g. a trailing newline at EOF)\n",
    "        t = json.loads(line)\n",
    "        candidate_triples.append((t['s'], t['p'], t['o']))\n",
    "print(f'Loaded {len(candidate_triples)} triples.')\n",
    "print(candidate_triples[:5])"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 4 triples.\n",
      "[('OCR post-correction', 'Compare', 'spelling correction'), ('OCR post-correction', 'Part-of', 'sequence-to-sequence model'), ('Neural network models', 'Compare', 'Multi-task learning'), ('Neural network models', 'Evaluate-for', 'Shared layers')]\n"
     ]
    }
   ],
   "execution_count": 27
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-01T07:26:16.883041Z",
     "start_time": "2024-09-01T07:26:16.867503Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Read the human-selected concepts ('|'-separated, no header row) and build\n",
    "# lookup tables in both directions: id -> concept and concept -> id.\n",
    "concepts_df = pd.read_csv('data/refined_concepts.tsv', sep='|', header=None,\n",
    "                          names=['id', 'concept'], index_col=0)\n",
    "id_2_concept = {}\n",
    "for concept_id, row in concepts_df.iterrows():\n",
    "    # Force the concept to str so both lookup tables use string concepts.\n",
    "    id_2_concept[concept_id] = str(row['concept'])\n",
    "concept_2_id = {name: concept_id for concept_id, name in id_2_concept.items()}\n",
    "print(id_2_concept[1])"
   ],
   "id": "4d7084593f4a7894",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "spelling correction\n"
     ]
    }
   ],
   "execution_count": 28
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-01T07:26:22.480160Z",
     "start_time": "2024-09-01T07:26:22.476369Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Load the text data (later cells read data[<concept>]['abstracts']).\n",
    "# The context manager closes the file once parsing is done rather than\n",
    "# leaving the handle open until garbage collection.\n",
    "with open('data/concept_abstracts_sample.json', 'r') as abstracts_file:\n",
    "    data = json.load(abstracts_file)"
   ],
   "id": "d34192c5a3770125",
   "outputs": [],
   "execution_count": 29
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-01T07:26:23.408079Z",
     "start_time": "2024-09-01T07:26:23.404799Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Load the relation-type definitions (a JSON object keyed by relation type).\n",
    "# The context manager makes sure the file handle is closed after parsing.\n",
    "with open('data/relation_types.json') as relation_file:\n",
    "    relation_def = json.load(relation_file)\n",
    "relation_types = list(relation_def.keys())\n",
    "# Map each relation type to its position in `relation_types`, and back.\n",
    "relation_2_id = {rel: idx for idx, rel in enumerate(relation_types)}\n",
    "id_2_relation = dict(enumerate(relation_types))"
   ],
   "id": "80451a6323286be",
   "outputs": [],
   "execution_count": 30
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-01T07:26:24.720609Z",
     "start_time": "2024-09-01T07:26:24.714486Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Build the prerequisite-of graph from a tab-separated file holding one\n",
    "# three-field triple per line.\n",
    "prerequisite_of_triples = []\n",
    "with open('data/prerequisite-of_graph.tsv', 'r') as f:\n",
    "    for line in f:\n",
    "        line = line.strip()\n",
    "        if not line:\n",
    "            continue  # a blank line would otherwise break the 3-way unpack\n",
    "        # str.split already returns str objects, so no extra casting is needed.\n",
    "        s, p, o = line.split('\\t')\n",
    "        prerequisite_of_triples.append((s, p, o))\n",
    "\n",
    "prerequisite_of_graph = get_nx_graph(prerequisite_of_triples, concept_2_id, relation_2_id)"
   ],
   "id": "79764d75081ef04b",
   "outputs": [],
   "execution_count": 31
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-01T07:26:28.108184Z",
     "start_time": "2024-09-01T07:26:28.105680Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Initialize the prompt template. The user-prompt text is read inside a\n",
    "# context manager so the file handle is closed right after reading.\n",
    "with open(\"prompts/prompt_fusion.txt\") as prompt_file:\n",
    "    prompt_template_txt = prompt_file.read()\n",
    "\n",
    "# Two-message chat prompt: a fixed system role plus the user text from the file.\n",
    "prompt_template = ChatPromptTemplate.from_messages([\n",
    "    (\"system\", \"You are a knowledge graph builder.\"),\n",
    "    (\"user\", prompt_template_txt)\n",
    "])"
   ],
   "id": "fe1911d8956459",
   "outputs": [],
   "execution_count": 33
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-01T07:26:34.721416Z",
     "start_time": "2024-09-01T07:26:34.718989Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Choose the concept under study and verbalize the candidate triples for it\n",
    "# (which triples are selected is decided by the helper from `graphs`).\n",
    "candidate_concept = \"spelling correction\"\n",
    "candidate_subgraph = verbalize_neighbors_triples_from_triples(\n",
    "    candidate_triples,\n",
    "    candidate_concept,\n",
    ")\n",
    "\n",
    "print(\n",
    "    f'Candidate subgraph of {candidate_concept}: \\n',\n",
    "    candidate_subgraph,\n",
    ")"
   ],
   "id": "5d6976e6b7ed85b1",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Candidate subgraph of spelling correction: \n",
      " (OCR post-correction,Compare,spelling correction)\n",
      "\n"
     ]
    }
   ],
   "execution_count": 34
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-01T07:28:58.386054Z",
     "start_time": "2024-09-01T07:28:58.383347Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Verbalize the same concept's neighbourhood in the prerequisite-of graph,\n",
    "# using the helper's 'bidirectional' mode.\n",
    "prerequisite_of_graph_subgraph = verbalize_neighbors_triples_from_graph(\n",
    "    prerequisite_of_graph,\n",
    "    candidate_concept,\n",
    "    concept_2_id,\n",
    "    id_2_concept,\n",
    "    mode='bidirectional',\n",
    ")\n",
    "print(\n",
    "    f'Prerequisite-of subgraph of {candidate_concept}: \\n',\n",
    "    prerequisite_of_graph_subgraph,\n",
    ")"
   ],
   "id": "a37a7f73d28d4df3",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prerequisite-of subgraph of spelling correction: \n",
      " (spelling correction,Is-a-Prerequisite-of,linguistics basics)\n",
      "(spelling correction,Is-a-Prerequisite-of,chinese nlp)\n",
      "(spelling correction,Is-a-Prerequisite-of,natural language processing intro)\n",
      "\n"
     ]
    }
   ],
   "execution_count": 44
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-01T07:28:59.891588Z",
     "start_time": "2024-09-01T07:28:59.888282Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Concatenate the abstracts stored for the concept; fall back to an empty\n",
    "# string when the concept has no entry in `data`.\n",
    "if candidate_concept in data:\n",
    "    abstracts = ' '.join(data[candidate_concept]['abstracts'])\n",
    "else:\n",
    "    abstracts = ''"
   ],
   "id": "96698b1ede02e4ef",
   "outputs": [],
   "execution_count": 45
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-01T07:30:19.470291Z",
     "start_time": "2024-09-01T07:30:19.461170Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Read the OpenAI API key from the local config file and export it through\n",
    "# the environment. The context manager closes the file handle right after\n",
    "# the JSON is parsed.\n",
    "with open('private_config.json') as config_file:\n",
    "    os.environ[\"OPENAI_API_KEY\"] = json.load(config_file)['OPENAI_API_KEY']\n",
    "# init the model\n",
    "model = KnowledgeGraphLLM(model_name=\"gpt-3.5-turbo\",\n",
    "                          max_tokens=400)"
   ],
   "id": "b05d842e391b983e",
   "outputs": [],
   "execution_count": 48
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-01T07:30:23.331602Z",
     "start_time": "2024-09-01T07:30:21.602484Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# One \"<relation type>: <description>\" line per entry in `relation_def`.\n",
    "relation_definitions_txt = '\\n'.join(\n",
    "    f\"{rel_type}: {rel_data['description']}\"\n",
    "    for rel_type, rel_data in relation_def.items()\n",
    ")\n",
    "\n",
    "# Fill in the prompt template's variables for the chosen concept.\n",
    "prompt = prompt_template.invoke({\n",
    "    \"concept\": candidate_concept,\n",
    "    \"graph1\": candidate_subgraph,\n",
    "    \"graph2\": prerequisite_of_graph_subgraph,\n",
    "    \"background\": abstracts,\n",
    "    \"relation_definitions\": relation_definitions_txt,\n",
    "})\n",
    "\n",
    "# query the model\n",
    "response = model.invoke(prompt)"
   ],
   "id": "cb238fc50c1fd107",
   "outputs": [],
   "execution_count": 49
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-01T07:32:18.356142Z",
     "start_time": "2024-09-01T07:32:18.352543Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Parse the model response as JSON and print each triple as \"s, p, o\".\n",
    "for triple in json.loads(response):\n",
    "    print(', '.join(triple[key] for key in ('s', 'p', 'o')))"
   ],
   "id": "4490fbcbf77fc2ff",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OCR post-correction, Compare, spelling correction\n",
      "spelling correction, Is-a-Prerequisite-of, linguistics basics\n",
      "spelling correction, Is-a-Prerequisite-of, chinese nlp\n",
      "spelling correction, Is-a-Prerequisite-of, natural language processing intro\n"
     ]
    }
   ],
   "execution_count": 53
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
